{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Aggregation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "from chemprop.nn.agg import MeanAggregation, SumAggregation, NormAggregation, AttentiveAggregation"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The tensors below stand in for the output of [message passing](./message_passing.ipynb), which is the input to aggregation."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "n_atoms_in_batch = 7\n",
    "hidden_dim = 3\n",
    "example_message_passing_output = torch.randn(n_atoms_in_batch, hidden_dim)\n",
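    "# Molecule index of each atom: atoms 0-1 -> molecule 0, atoms 2-5 -> molecule 1, atom 6 -> molecule 2\n",
    "which_atoms_in_which_molecule = torch.tensor([0, 0, 1, 1, 1, 1, 2]).long()"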
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Combine nodes"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The aggregation layer combines the node-level representations into a graph-level representation (usually atoms -> molecule)."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Mean and sum aggregation"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Mean aggregation is recommended when the target property is intensive, i.e. it does not depend on the number of atoms in the molecule. Sum aggregation is recommended when the property is extensive (it scales with the number of atoms), though norm aggregation usually works better."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "mean_agg = MeanAggregation()\n",
    "sum_agg = SumAggregation()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[-0.4593, -0.1808, -0.3459],\n",
       "        [ 0.9343, -0.1746,  0.7430],\n",
       "        [-0.4747, -0.9394, -0.3877]])"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "mean_agg(H=example_message_passing_output, batch=which_atoms_in_which_molecule)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[-0.9187, -0.3616, -0.6917],\n",
       "        [ 3.7373, -0.6986,  2.9720],\n",
       "        [-0.4747, -0.9394, -0.3877]])"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "sum_agg(H=example_message_passing_output, batch=which_atoms_in_which_molecule)"
   ]
  },
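  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a minimal sketch of what these two layers compute, the cell below reproduces per-molecule sums and means with plain PyTorch on a small hand-written tensor. The names `H`, `batch`, `manual_sum`, and `manual_mean` are illustrative only, and this is not necessarily how chemprop implements the layers internally."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "\n",
    "# Small hand-written example so the arithmetic is easy to follow.\n",
    "H = torch.tensor([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]])\n",
    "batch = torch.tensor([0, 0, 1])  # atoms 0 and 1 -> molecule 0, atom 2 -> molecule 1\n",
    "n_mols = int(batch.max()) + 1\n",
    "\n",
    "# Sum: add each atom's row into the row of its molecule.\n",
    "manual_sum = torch.zeros(n_mols, H.shape[1]).index_add(0, batch, H)\n",
    "\n",
    "# Mean: divide each molecule's sum by its atom count.\n",
    "atom_counts = torch.bincount(batch, minlength=n_mols).unsqueeze(1)\n",
    "manual_mean = manual_sum / atom_counts\n",
    "\n",
    "print(manual_sum)   # rows: [4., 6.] and [5., 6.]\n",
    "print(manual_mean)  # rows: [2., 3.] and [5., 6.]"
   ]
  },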
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Norm aggregation"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Norm aggregation divides the per-molecule sum by a constant. It can work better than sum aggregation for large molecules, because it keeps the hidden representation values on the order of 1 (this matters less when batch normalization is used). The normalization constant defaults to 100.0 and can be customized."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "norm_agg = NormAggregation()\n",
    "big_norm = NormAggregation(norm=1000.0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[-0.0092, -0.0036, -0.0069],\n",
       "        [ 0.0374, -0.0070,  0.0297],\n",
       "        [-0.0047, -0.0094, -0.0039]])"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "norm_agg(H=example_message_passing_output, batch=which_atoms_in_which_molecule)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[-0.0009, -0.0004, -0.0007],\n",
       "        [ 0.0037, -0.0007,  0.0030],\n",
       "        [-0.0005, -0.0009, -0.0004]])"
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "big_norm(H=example_message_passing_output, batch=which_atoms_in_which_molecule)"
   ]
  },
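  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To make the scaling concrete, the sketch below divides per-molecule sums by a constant using plain PyTorch. It assumes, consistent with the outputs above, that norm aggregation is the sum divided by the normalization constant. The tensor values and variable names are made up for illustration."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "\n",
    "# Deliberately large values, as might accumulate when summing over many atoms.\n",
    "H = torch.tensor([[100.0, 200.0], [300.0, 400.0], [500.0, 600.0]])\n",
    "batch = torch.tensor([0, 0, 1])  # atoms 0 and 1 -> molecule 0, atom 2 -> molecule 1\n",
    "norm = 100.0  # illustrative constant, equal to the default mentioned above\n",
    "n_mols = int(batch.max()) + 1\n",
    "\n",
    "# Sum per molecule, then rescale by the constant.\n",
    "manual_sum = torch.zeros(n_mols, H.shape[1]).index_add(0, batch, H)\n",
    "manual_norm = manual_sum / norm\n",
    "\n",
    "print(manual_norm)  # rows: [4., 6.] and [5., 6.]"
   ]
  },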
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Attentive aggregation"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Attentive aggregation combines the atom representations within a molecular graph using a learned weighted average. It must be given the size of the hidden dimension, because it computes each atom's weight from that atom's hidden representation."
   ]
  },
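  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "A rough sketch of the idea in plain PyTorch follows. The scoring layer, the use of `exp` to make the weights positive, and all variable names are assumptions chosen for illustration, not necessarily what `AttentiveAggregation` does internally. Note that a molecule with a single atom always gets weight 1, so its aggregated vector equals that atom's representation."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "\n",
    "torch.manual_seed(0)\n",
    "H = torch.randn(4, 3)                  # 4 atoms, hidden size 3\n",
    "batch = torch.tensor([0, 0, 0, 1])     # atoms 0-2 -> molecule 0, atom 3 -> molecule 1\n",
    "n_mols = int(batch.max()) + 1\n",
    "\n",
    "# Hypothetical scoring layer: maps each atom's hidden vector to one scalar.\n",
    "score_layer = torch.nn.Linear(H.shape[1], 1)\n",
    "scores = score_layer(H).exp()          # positive, unnormalized weight per atom\n",
    "\n",
    "# Normalize the weights within each molecule so they sum to 1.\n",
    "totals = torch.zeros(n_mols, 1).index_add(0, batch, scores)\n",
    "weights = scores / totals[batch]\n",
    "\n",
    "# Weighted average of the atom representations in each molecule.\n",
    "weighted_avg = torch.zeros(n_mols, H.shape[1]).index_add(0, batch, weights * H)\n",
    "\n",
    "print(weighted_avg.shape)  # torch.Size([2, 3])"
   ]
  },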
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "att_agg = AttentiveAggregation(output_size=hidden_dim)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[-0.4551, -0.1791, -0.3438],\n",
       "        [ 0.9370,  0.1375,  0.3714],\n",
       "        [-0.4747, -0.9394, -0.3877]], grad_fn=<ScatterReduceBackward0>)"
      ]
     },
     "execution_count": 10,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "att_agg(H=example_message_passing_output, batch=which_atoms_in_which_molecule)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "chemprop",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
